{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from aco import ACO\n",
    "import numpy as np\n",
    "from scipy.spatial import distance_matrix\n",
    "import logging\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "N_ANTS = 30\n",
    "N_ITERATIONS = [1, 10, 30, 50, 100, 150, 200]\n",
    "\n",
    "\n",
    "\n",
    "def heuristics_reevo(edge_attr):\n",
    "    num_edges = edge_attr.shape[0]\n",
    "    num_attributes = edge_attr.shape[1]\n",
    "\n",
    "    heuristic_values = np.zeros_like(edge_attr)\n",
    "\n",
    "    # Apply feature engineering on edge attributes\n",
    "    transformed_attr = np.log1p(np.abs(edge_attr))  # Taking logarithm of absolute value of attributes\n",
    "    \n",
    "    # Normalize edge attributes\n",
    "    scaler = StandardScaler()\n",
    "    edge_attr_norm = scaler.fit_transform(transformed_attr)\n",
    "\n",
    "    # Calculate correlation coefficients\n",
    "    correlation_matrix = np.corrcoef(edge_attr_norm.T)\n",
    "\n",
    "    # Calculate heuristic value for each edge attribute\n",
    "    for i in range(num_edges):\n",
    "        for j in range(num_attributes):\n",
    "            if edge_attr_norm[i][j] != 0:\n",
    "                heuristic_values[i][j] = np.exp(-8 * edge_attr_norm[i][j] * correlation_matrix[j][j])\n",
    "\n",
    "    return heuristic_values\n",
    "\n",
    "\n",
    "def solve(node_pos):\n",
    "    dist_mat = distance_matrix(node_pos, node_pos)\n",
    "    dist_mat[np.diag_indices_from(dist_mat)] = 1 # set diagonal to a large number\n",
    "    heu = heuristics_reevo(dist_mat) + 1e-9\n",
    "    heu[heu < 1e-9] = 1e-9\n",
    "    aco = ACO(dist_mat, heu, n_ants=N_ANTS)\n",
    "    \n",
    "    results = []\n",
    "    for i in range(len(N_ITERATIONS)):\n",
    "        if i == 0:\n",
    "            obj = aco.run(N_ITERATIONS[i])\n",
    "        else:\n",
    "            obj = aco.run(N_ITERATIONS[i] - N_ITERATIONS[i-1])\n",
    "        # print(\"Iteration: {}, Objective: {}\".format(N_ITERATIONS[i], obj))\n",
    "        results.append(obj.item())\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[*] Running ...\n",
      "[*] Average for 20, 1 iterations: 3.9138639767304517\n",
      "[*] Average for 20, 10 iterations: 3.868449540571445\n",
      "[*] Average for 20, 30 iterations: 3.864481260784807\n",
      "[*] Average for 20, 50 iterations: 3.864481260784807\n",
      "[*] Average for 20, 100 iterations: 3.8639757221712103\n",
      "[*] Average for 20, 150 iterations: 3.8639757221712103\n",
      "[*] Average for 20, 200 iterations: 3.8633489771380023\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(\"[*] Running ...\")\n",
    "\n",
    "for problem_size in [20]:\n",
    "    dataset_path = f\"./dataset/val{problem_size}_dataset.npy\"\n",
    "    node_positions = np.load(dataset_path)\n",
    "    logging.info(f\"[*] Evaluating {dataset_path}\")\n",
    "    n_instances = node_positions.shape[0]\n",
    "    objs = []\n",
    "    for i, node_pos in enumerate(node_positions):\n",
    "        obj = solve(node_pos)\n",
    "        objs.append(obj)\n",
    "    # Average objective value for all instances\n",
    "    mean_obj = np.mean(objs, axis=0)\n",
    "    for i, obj in enumerate(mean_obj):\n",
    "        print(f\"[*] Average for {problem_size}, {N_ITERATIONS[i]} iterations: {obj}\")\n",
    "    print()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
